# agnostic to objective function or data splits


# grid_partition -----------------

#' Grid Partition
#' 
#' A \code{\link{grid_partition}} defines a grid over a feature-space. It can be built by composing \code{\link{partition_split}}s. 
#' 
#' The partition is typically built by a search algorithm such as \code{\link{fit_partition}}.
#' @name GridPartition
NULL
#> NULL

#' Create a null \code{grid_partition}
#' 
#' Create an empty partition (no splits). Splits can be added using \code{\link{add_partition_split}}.
#' Information about a partition can be retrieved using \code{\link{num_cells}}, \code{\link{get_desc_df}} and \code{\link{print}}.
#' With data, one can determine the cell for each observation using \code{\link{predict}}.
#'
#' Dimensions whose range is a character vector (factor levels) are treated as categorical;
#' all other dimensions are treated as continuous.
#'
#' @param X_range Such as from \code{\link{get_X_range}}
#' @param varnames Names of the X-variables
#'
#' @return Grid Partition
#' @export
grid_partition <- function(X_range, varnames=NULL) {
  K <- length(X_range)
  # Per-dimension split storage: a list of level-sets for categorical
  # dimensions, a numeric vector of cut points for continuous ones.
  s_by_dim <- vector("list", length=K)
  dim_cat <- c()
  # seq_len() so an empty X_range yields an empty partition; 1:K would
  # iterate over c(1, 0) and fail on X_range[[1]].
  for (k in seq_len(K)) {
    if (mode(X_range[[k]]) == "character") {
      dim_cat <- c(dim_cat, k)
      s_by_dim[[k]] <- list()
    } else {
      s_by_dim[[k]] <- vector("numeric")
    }
  }
  nsplits_by_dim <- rep(0, K)

  structure(list(s_by_dim = s_by_dim, nsplits_by_dim = nsplits_by_dim, varnames = varnames,
                 dim_cat = dim_cat, X_range = X_range),
            class = c("grid_partition"))
}

#' Is grid_partition
#' 
#' Test whether an object is a \code{grid_partition}.
#'
#' @param x an R object
#'
#' @return TRUE if x inherits from class "grid_partition", FALSE otherwise
#' @export
#' @describeIn grid_partition is grid_partition
is_grid_partition <- function(x) inherits(x, what = "grid_partition")


#' Get X_range
#' 
#' Gets the "range" of each variable in X. For numeric variables this is (min, max).
#' For factors this means the vector of levels.
#'
#' @param X data: a numeric matrix, or a data.frame (or tibble) whose columns are numeric or
#'          factors. A list of such objects (separate samples) is row-bound first.
#'
#' @return list of length K with each element being the "range" along that dimension
#' @export
get_X_range <- function(X) {
  if (is_sep_sample(X)) X <- do.call("rbind", X)
  if (is.matrix(X)) {
    # are_equal() only returns TRUE/FALSE; it has to be wrapped in
    # assert_that() for the check to have any effect.
    assert_that(are_equal(mode(X), "numeric"), msg="X is a matrix but is not numeric")
  } else {
    assert_that(is.data.frame(X), msg="X is not a matrix or data.frame")
    # tibbles return a tibble (rather than a vector) for X[, k], making
    # is.factor(X[, k]) and others fail, so drop to a plain data.frame.
    if (inherits(X, "tbl")) X <- as.data.frame(X)
    for (k in seq_len(ncol(X))) {
      # Factors have mode "numeric", so this admits numeric and factor columns
      # and rejects e.g. character columns (whose range() would be meaningless here).
      assert_that(are_equal(mode(X[[k]]), "numeric"),
                  msg=paste0("Column ", k, " of X is not numeric or a factor"))
    }
  }
  assert_that(ncol(X) >= 1, msg="X has no columns")

  lapply(seq_len(ncol(X)), function(k) {
    X_k <- X[, k]
    if (is.factor(X_k)) levels(X_k) else range(X_k) #c(min, max)
  })
}


#' Get factor describing cell number for each observation
#' 
#' Note that currently if X has values more extreme (e.g., for numeric or factor levels) than were used to generate the partition
#' then we will return NA unless you provide an updated X_range.
#'
#' @param object partition
#' @param X X data or list of X
#' @param X_range (Optional) overrides the partition$X_range
#' @param ... Additional arguments. Unused.
#'
#' @return Factor (computed per sample when X is a list of separate samples)
#' @export
predict.grid_partition <- function(object, X, X_range=NULL, ...) {
  per_dim_factors <- get_factors_from_partition(object, X, X_range=X_range)
  sep_samples <- is_sep_sample(X)
  interaction_m(per_dim_factors, sep_samples)
}


#' @describeIn num_cells grid_partition
#' @export
num_cells.grid_partition <- function(obj) {
  # A dimension with s splits has s + 1 segments; cells are their cross product.
  segments_per_dim <- obj$nsplits_by_dim + 1
  prod(segments_per_dim)
}

#' Print grid_partition
#' 
#' Prints the descriptive data.frame of the partition (see \code{get_desc_df}) with options.
#'
#' @param x partition object
#' @param do_str If TRUE, use a string like "(a, b]"; otherwise have two separate columns with a and b
#' @param drop_unsplit If TRUE, drop columns for variables over which the partition did not split
#' @param digits digits option, used both for formatting cell bounds and for printing
#' @param ... Additional arguments. Passed to print for the data.frame
#'
#' @return The value returned by \code{print} on the descriptive data.frame (which is displayed)
#' @export
print.grid_partition <- function(x, do_str=TRUE, drop_unsplit=TRUE, digits=NULL, ...) {
  #To check: digits
  assert_that(is.flag(do_str), is.flag(drop_unsplit), msg="One of do_str or drop_unsplit are not flags")
  desc <- get_desc_df(x, do_str=do_str, drop_unsplit=drop_unsplit, digits=digits)
  print(desc, digits=digits, ...)
}


#' Get descriptive data.frame
#'
#' Get information for each cell
#'
#' @inheritParams get_desc_df
#' 
#' 
#' @return data.frame with one row per cell and one column per partitioning variable (only
#'         variables that were split, if \code{drop_unsplit}). With \code{do_str=TRUE} the columns
#'         are factors of formatted strings; otherwise they are list-columns holding each cell's window.
#' @export
get_desc_df.grid_partition <- function(obj, cont_bounds_inf=TRUE, do_str=FALSE, drop_unsplit=FALSE, 
                                       digits=NULL, unsplit_cat_star=TRUE, ...) {
  #To check: digits
  assert_that(is.flag(cont_bounds_inf), is.flag(do_str), is.flag(drop_unsplit), is.flag(unsplit_cat_star),
              msg="One of (cont_bounds_inf, do_str, drop_unsplit, unsplit_cat_star) are not flags.")
  # A split at x_k means that we split to those <= and >

  n_segs <- obj$nsplits_by_dim + 1
  n_cells <- prod(n_segs)

  # A single cell with every (unsplit) column dropped: nothing to describe.
  if (n_cells == 1 && drop_unsplit) return(as.data.frame(matrix(NA, nrow=1, ncol=0)))

  K <- length(obj$nsplits_by_dim)
  X_range <- obj$X_range
  if (cont_bounds_inf) {
    # Show the outermost bounds of continuous dimensions as unbounded.
    for (k in seq_len(K)) {
      if (!k %in% obj$dim_cat) X_range[[k]] <- c(-Inf, Inf)
    }
  }
  colnames <- obj$varnames
  if (is.null(colnames)) colnames <- paste0("X", seq_len(K))

  # Windows along each dimension: level-sets for categorical dims, c(lo, hi) for continuous.
  list_of_windows <- lapply(seq_len(K), function(k) {
    if (k %in% obj$dim_cat) {
      get_windows_cat(obj$s_by_dim[[k]], X_range[[k]])
    } else {
      get_window_cont(obj$s_by_dim[[k]], X_range[[k]])
    }
  })

  format_cell_cat <- function(win, unsplit_cat_star, n_tot_dim, sep=", ") {
    if (unsplit_cat_star && n_tot_dim == 1) return("*")
    return(paste(win, collapse=sep))
  }
  format_cell_cont <- function(win) {
    if (is.infinite(win[1]) && is.infinite(win[2])) return("*")
    if (is.infinite(win[1])) return(paste0("<=", format(win[2], digits=digits)))
    if (is.infinite(win[2])) return(paste0(">", format(win[1], digits=digits)))
    return(paste0("(", format(win[1], digits=digits), ", ", format(win[2], digits=digits), "]"))
  }

  # Segment index of every cell along every dimension, computed once (it does not depend on k).
  seg_index_by_cell <- if (K > 0) lapply(seq_len(n_cells), segment_indexes_from_cell_i, n_segs) else list()

  raw_data <- data.frame(row.names=seq_len(n_cells))
  str_data <- data.frame(row.names=seq_len(n_cells))
  for (k in seq_len(K)) {
    is_cat <- k %in% obj$dim_cat
    n_windows_k <- length(list_of_windows[[k]])
    raw_data_k <- vector("list", n_cells)
    str_data_k <- character(n_cells)
    for (cell_i in seq_len(n_cells)) {
      win <- list_of_windows[[k]][[seg_index_by_cell[[cell_i]][k]]]
      raw_data_k[[cell_i]] <- win
      str_data_k[cell_i] <- if (is_cat) format_cell_cat(win, unsplit_cat_star, n_windows_k) else format_cell_cont(win)
    }
    raw_data[[colnames[k]]] <- cbind(raw_data_k) #make a list-column: https://stackoverflow.com/a/51308306
    str_data[[colnames[k]]] <- factor(str_data_k, levels=unique(str_data_k)) #will be in low-high order
  }
  desc_df <- if (do_str) str_data else raw_data
  if (drop_unsplit) desc_df <- desc_df[n_segs > 1]

  desc_df
}

#' Adds partition_split to grid_partition
#' 
#' Returns a copy of the partition with one more split along dimension \code{s[[1]]}.
#' For a categorical dimension the cut (a set of levels) is appended to that dimension's
#' list of splits; for a continuous dimension the cut value is merged into the sorted cut points.
#'
#' @param obj Grid Partition object
#' @param s Partition Split object (first element: dimension; second element: cut)
#'
#' @return updated Grid Partition
#' @export
add_partition_split <- function(obj, s) {
  dim_k <- s[[1]]
  cut_val <- s[[2]]
  n_prev <- obj$nsplits_by_dim[dim_k]

  if (dim_k %in% obj$dim_cat) {
    obj$s_by_dim[[dim_k]][[n_prev + 1]] <- cut_val
  } else {
    obj$s_by_dim[[dim_k]] <- sort(c(cut_val, obj$s_by_dim[[dim_k]]))
  }
  obj$nsplits_by_dim[dim_k] <- n_prev + 1

  obj
}

# Map one column of X to a factor of the partition's segments along that dimension.
#  - Categorical (X_k_range is character): X_k is assumed to be a factor. Its levels are
#    relabelled so that levels in the same window merge into a single level named "{a,b}"
#    (or the bare level name for a one-level window); the result's levels follow window order.
#  - Continuous: cut() at the split points with the default right-closed (a, b] intervals.
#    The outer breaks sit one unit beyond X_k_range, so values <= X_k_range[1]-1 or
#    > X_k_range[2]+1 become NA.
get_factors_from_splits_dim <- function(X_k, X_k_range, s_by_dim_k) {
  if (mode(X_k_range) == "character") {
    windows <- get_windows_cat(s_by_dim_k, X_k_range)
    fac <- X_k
    new_name_map <- levels(fac)
    new_names <- c()
    for (window in windows) {
      new_name <- if (length(window) > 1) paste0("{", paste(window, collapse=","), "}") else window[1]
      new_name_map[levels(fac) %in% window] <- new_name
      new_names <- c(new_names, new_name)
    }
    levels(fac) <- new_name_map
    return(factor(fac, new_names))
  }

  # Push the outer breaks one unit past the range so the min and max land strictly
  # inside the first/last intervals.
  breaks <- c(X_k_range[1] - 1, s_by_dim_k, X_k_range[2] + 1)
  # Note: cut.default has no `include.lower` argument (only `include.lowest`); the
  # previously passed `include.lower=TRUE` was silently absorbed by `...`, so it is omitted.
  cut(X_k, breaks, labels=NULL)
}

# Apply get_factors_from_splits_dim() to column k of X, or to column k of each dataset
# when X is a list of separate samples (a list of factors is returned in that case).
get_factors_from_splits_dim_m <- function(X, X_k_range, s_by_dim_k, k) {
  factor_for <- function(X_s) get_factors_from_splits_dim(X_s[, k], X_k_range, s_by_dim_k)
  if (is_sep_sample(X)) {
    lapply(X, factor_for)
  } else {
    factor_for(X)
  }
}

# For a continuous variable, a split is just a value (a cut point).
# For a factor variable, a split is a vector of levels (strings).

# Placeholder X_range of K unbounded continuous dimensions, each c(-Inf, Inf).
# Returns list() for K = 0 (a 1:K loop would iterate over c(1, 0) and fail).
dummy_X_range <- function(K) {
  rep(list(c(-Inf, Inf)), K)
}

# Decode a cell number into its segment index along each dimension.
# Mixed-radix decoding where the first dimension is the least significant (fastest
# changing), rather than lexicographic. cell_i and the returned indexes are 1-indexed.
# e.g., with n_segments = c(2, 3): cell 1 -> c(1, 1), 2 -> c(2, 1), 3 -> c(1, 2), 6 -> c(2, 3).
segment_indexes_from_cell_i <- function(cell_i, n_segments) {
  K <- length(n_segments)
  n_cells <- prod(n_segments)
  # Out-of-range cells would otherwise silently wrap around to a valid-looking index.
  if (cell_i < 1 || cell_i > n_cells) {
    stop("cell_i=", cell_i, " is out of range [1, ", n_cells, "]", call. = FALSE)
  }
  index <- rep(0, K)
  cell_i_rem <- cell_i - 1 #convert to 0-indexing
  for (k in seq_len(K)) {
    index[k] <- cell_i_rem %% n_segments[k]
    cell_i_rem <- cell_i_rem %/% n_segments[k]
  }
  index + 1 #convert from 0-indexing
}

# Build a grid_partition by applying, in order, the first
# min(length(split_seq), max_include) splits of split_seq to an empty partition.
partition_from_split_seq <- function(split_seq, X_range, varnames=NULL, max_include=Inf) {
  n_use <- min(length(split_seq), max_include)
  Reduce(add_partition_split, split_seq[seq_len(n_use)], init = grid_partition(X_range, varnames))
}

# For each dimension, compute the factor of partition segments that each observation falls in.
# Returns a list of K factors, or, for separate samples, a list over datasets of such lists.
# X_range, if given, overrides partition$X_range.
get_factors_from_partition <- function(partition, X, X_range=NULL) {
  if (is.null(X_range)) X_range <- partition$X_range

  # Per-dimension factors for one dataset (seq_len() so K = 0 gives list()).
  factors_for_sample <- function(X_m, K) {
    lapply(seq_len(K), function(k) {
      get_factors_from_splits_dim(X_m[, k], X_range[[k]], partition$s_by_dim[[k]])
    })
  }

  if (is_sep_sample(X)) {
    # All datasets are assumed to share the column count of the first one.
    K <- ncol(X[[1]])
    return(lapply(seq_along(X), function(m) factors_for_sample(X[[m]], K)))
  }
  factors_for_sample(X, ncol(X))
}

# partition_split ---------------------

#' Create partition_split
#' 
#' Describes a single partition split. Used with \code{\link{add_partition_split}}.
#'
#' @param k dimension
#' @param X_k_cut cut value (a number for continuous dimensions, a vector of levels for factors)
#'
#' @return Partition Split: a list with elements \code{k} and \code{X_k_cut} of class "partition_split"
#' @export
partition_split <- function(k, X_k_cut) {
  split <- list(k=k, X_k_cut=X_k_cut)
  class(split) <- "partition_split"
  split
}

#' Is \code{partition_split}
#' 
#' Tests whether or not an object is a \code{partition_split}.
#'
#' @param x an R object
#'
#' @return TRUE if x inherits from class "partition_split", FALSE otherwise
#' @export
#' @describeIn partition_split is partition_split
is_partition_split <- function(x) inherits(x, what = "partition_split")

#' Print partition_split
#' 
#' Prints information for a \code{partition_split} on one line as "k: cut". For a categorical
#' cut (a vector of levels) the levels are joined with ", ".
#'
#' @param x Object
#' @param ... Additional arguments. Unused.
#'
#' @return x, invisibly
#' @export
print.partition_split <- function(x, ...) {
  # Collapse the cut so a multi-level categorical cut prints on a single line
  # (paste0 alone would vectorise and emit one "k: level" line per level).
  cat(paste0(x[[1]], ": ", paste(x[[2]], collapse=", "), "\n"))
  invisible(x)
}

# Search algo --------------------


#' Fit grid_partition
#' 
#' Fit partition on some data, optionally finding best lambda using CV and then re-fitting on full data.
#' 
#' Returns the partition and information about the fitting process
#'  
#' @section Multiple estimates:
#' With multiple core estimates (M) there are 3 options (the first two have the same sample across treatment effects).\enumerate{
#'  \item DS.MULTI_SAMPLE: Multiple pairs of (Y_{m},W_{m}). y,X,d are then lists of length M. Each element then has the typical size
#'     The N_m may differ across m. The number of columns of X will be the same across m.
#'  \item DS.MULTI_D: Multiple treatments and a single outcome. d is then a NxM matrix.
#'  \item DS.MULTI_Y: A single treatment and multiple outcomes. y is then a NXM matrix.
#' }
#'
#' @param y Nx1 matrix of outcome (label/target) data. With multiple core estimates see Details below.
#' @param X NxK matrix of features (covariates). With multiple core estimates see Details below.
#' @param d (Optional) NxP matrix (with colnames) of treatment data. If all equally important they 
#'          should be normalized to have the same variance. With multiple core estimates see Details below.
#' 
#' @param X_aux aux X sample to compute statistics on (OOS data)
#' @param d_aux aux d sample to compute statistics on (OOS data)
#' @param max_splits Maximum number of splits even if splits continue to improve OOS fit
#' @param max_cells Maximum number of cells even if more splits continue to improve OOS fit
#' @param min_size Minimum cell size when building full grid, cv_tr will use (F-1)/F*min_size, cv_te doesn't use any.
#' @param cv_folds Number of CV Folds or a vector of foldids. 
#'                If m_mode==DS.MULTI_SAMPLE, then a list with foldids per Dataset.
#' @param verbosity 0 print no message. 
#'                  1 prints progress bar for high-level loops. 
#'                  2 prints detailed output for high-level loops. 
#'                  Nested operations decrease verbosity by 1.
#' @param breaks_per_dim NULL (for all possible breaks); 
#'                       K-length vector with # of break (chosen by quantiles); or 
#'                       K-dim list of vectors giving potential split points for non-categorical 
#'                         variables (can put c(0) for categorical). 
#'                      Similar to 'discrete splitting' in CausalTree though they do separate split-points 
#'                      for treated and controls.
#' @param potential_lambdas potential lambdas to search through in CV
#' @param X_range list of min/max for each dimension (e.g., from \code{\link{get_X_range}})
#' @param bucket_min_n Minimum number of observations needed between different split checks
#' @param bucket_min_d_var Ensure positive variance of d for the observations between different split checks
#' @param obj_fn Default is \code{\link{eval_mse_hat}}. User-provided must allow same signature.
#' @param est_plan \link{EstimatorPlan}.
#' @param partition_i Default NA. Use this to avoid CV
#' @param pr_cl Default NULL. Parallel cluster. Used for:\enumerate{
#'                \item CVing the optimal lambda, 
#'                \item fitting full tree (at each split going across dimensions), 
#'                \item fitting trees over the bumped samples
#'              }
#' @param bump_samples Number of bump bootstraps (default 0), or list of such length where each item is a bootstrap sample.
#'                     If m_mode==DS.MULTI_SAMPLE then each item is a sublist with such bootstrap samples over each dataset.
#' @param bump_ratio For bootstraps the ratio of sample size to sample (between 0 and 1, default 1)
#' @param ... Additional params.
#'
#' @return An object.
#'         \item{partition}{Grid Partition (type=\code{\link{grid_partition}})}
#'         \item{is_obj_val_seq}{Full sequence of in-sample objective function values}
#'         \item{complexity_seq}{Full sequence of partition complexities (num_cells - 1)}
#'         \item{partition_i}{Index of partition chosen}
#'         \item{partition_seq}{Full sequence of Grid Partitions}
#'         \item{split_seq}{Full sequence of splits (type=\code{\link{partition_split}})}
#'         \item{lambda}{lambda chosen}
#'         \item{foldids}{Fold assignment used for CV (as returned by expand_fold_info); NA if CV was not run}
#' @export
fit_partition <- function(y, X, d=NULL, X_aux=NULL, d_aux=NULL, max_splits=Inf, max_cells=Inf, 
                          min_size=3, cv_folds=2, verbosity=0, breaks_per_dim=NULL, potential_lambdas=NULL, 
                          X_range=NULL, bucket_min_n=NA, bucket_min_d_var=FALSE, obj_fn, 
                          est_plan, partition_i=NA, pr_cl=NULL, bump_samples=0, bump_ratio=1, ...) {
  #Hidden params:
  # - @param lambda_1se Use the 1se rule to pick the best lambda
  # - @param valid_fn Function to quickly check if partition could be valid. User can override.
  # - @param split_check_fn Alternative split-check function
  # - @param N_est N of samples in the Estimation dataset
  # - @param nsplits_k_warn_limit
  # - @param bump_complexity, method 1 is c(FALSE, FALSE), method 2 is c(FALSE, TRUE), and method 3 is c(TRUE)
  extra_params <- list(...)
  valid_fn <- split_check_fn <- NULL
  lambda_1se <- FALSE
  N_est <- NA
  nsplits_k_warn_limit <- 200
  bump_complexity <- list(doCV=FALSE, incl_comp_in_pick=FALSE)
  if (length(extra_params) > 0) {
    if ("valid_fn" %in% names(extra_params)) valid_fn <- extra_params[['valid_fn']]
    if ("split_check_fn" %in% names(extra_params)) split_check_fn <- extra_params[['split_check_fn']]
    if ("lambda_1se" %in% names(extra_params)) lambda_1se <- extra_params[['lambda_1se']]
    if ("N_est" %in% names(extra_params)) N_est <- extra_params[['N_est']]
    if ("nsplits_k_warn_limit" %in% names(extra_params)) nsplits_k_warn_limit <- extra_params[['nsplits_k_warn_limit']]
    if ("bump_complexity" %in% names(extra_params)) bump_complexity <- extra_params[['bump_complexity']]
    good_args <- c("valid_fn", "split_check_fn", "lambda_1se", "N_est","nsplits_k_warn_limit", "bump_complexity")
    bad_names <- names(extra_params)[!(names(extra_params) %in% good_args)]
    assert_that(length(bad_names)==0, msg=paste(c(list("Illegal arguments:"), bad_names), collapse = " "))
  }

  #To check: y, X, d, N_est, X_aux, d_aux, breaks_per_dim, potential_lambdas, X_range, bucket_min_n
  assert_that(max_splits>0, max_cells>0, min_size>0, msg="max_splits, max_cells, min_size need to be positive")
  assert_that(is.flag(lambda_1se), is.flag(bucket_min_d_var), msg="One of (lambda_1se, bucket_min_d_var) are not flags.") 
  assert_that(inherits(est_plan, "estimator_plan") || (is.list(est_plan) && inherits(est_plan[[1]], "estimator_plan")), msg="estimator_plan argument (or it's first element) doesn't inherit from estimator_plan class") 
  #verbosity can be negative if decrementd from a fit_estimate call
  list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE)
  if (is_sep_sample(X) && length(cv_folds) > 1) {
    assert_that(is.list(cv_folds) && length(cv_folds)==M, msg="When separate samples and length(cv_folds)>1, need is.list(cv_folds) && length(cv_folds)==M.")
  }
  check_M_K(M, m_mode, K, X_aux, d_aux)
  do_cv <- is.na(partition_i) && (is.null(potential_lambdas) || length(potential_lambdas) > 0)
  # A (non-empty) list of bootstrap samples means bumping even if it has a single element;
  # comparing a list to 0 would otherwise error.
  do_bump <- if (is.list(bump_samples)) length(bump_samples) > 0 else (length(bump_samples) > 1 || bump_samples > 0)
  if (!do_cv) assert_that(bump_complexity$doCV==FALSE, msg="When not doing CV, can't including bumping in CV.")
  if (do_bump && bump_complexity$doCV) {
    # A single bumping spec is reused for both the CV stage and the final fit.
    # (Was `length(bump_samples==1)`, which is truthy for any non-empty input.)
    if (length(bump_samples) == 1) bump_samples <- list(bump_samples, bump_samples)
    cv_bump_samples <- bump_samples[[1]]
    bump_samples <- bump_samples[[2]]
  } else {
    cv_bump_samples <- 0
  }

  if (is.null(X_range)) X_range <- get_X_range(X)
  if (!is.list(breaks_per_dim) && length(breaks_per_dim)==1) breaks_per_dim <- get_quantile_breaks(X, X_range, g=breaks_per_dim)
  if (is.null(valid_fn)) valid_fn <- valid_partition

  # Build the default split check from the bucket options, but never discard a
  # user-supplied split_check_fn (the previous else-branch reset it to NULL).
  if (is.null(split_check_fn) && (!is.na(bucket_min_n) || bucket_min_d_var)) {
    split_check_fn <- purrr::partial(rolling_split_check, bucket_min_n=bucket_min_n, bucket_min_d_var=bucket_min_d_var)
  }

  if (verbosity > 0) cat("Grid: Started.\n")


  if (verbosity > 0) cat("Grid: Fitting grid structure on full set\n")
  fit_ret <- fit_partition_full(y, X, d, X_aux, d_aux, X_range=X_range, max_splits=max_splits, 
                                max_cells=max_cells, min_size=min_size,  verbosity=verbosity-1, 
                                breaks_per_dim=breaks_per_dim, N_est=N_est, split_check_fn=split_check_fn, 
                                obj_fn=obj_fn, allow_empty_aux=FALSE, 
                                allow_est_errors_aux=FALSE, min_size_aux=1, est_plan=est_plan, 
                                pr_cl=pr_cl, valid_fn=valid_fn, nsplits_k_warn_limit=nsplits_k_warn_limit)
  list[partition_seq, is_obj_val_seq, split_seq] = fit_ret
  complexity_seq <- vapply(partition_seq, num_cells, numeric(1)) - 1

  foldids <- NA
  if (!is.na(partition_i)) {
    lambda <- NA
    max_splits <- partition_i-1
    if (length(partition_seq) < partition_i) {
      cat("Note: Couldn't build grid to desired granularity. Using most granular\n")
      partition_i <- length(partition_seq)
    }
    assert_that(bump_complexity$incl_comp_in_pick==FALSE, msg="When no complexity penalization used, can't include complexity cost in bumping calculation.")
  } else {
    if (do_cv) {
      list[nfolds, folds_ret, foldids] = expand_fold_info(y, cv_folds, m_mode)
      list[lambda,lambda_oos, n_cell_table] = cv_pick_lambda(y=y, X=X, d=d, folds_ret=folds_ret, nfolds=nfolds, potential_lambdas=potential_lambdas, N_est=N_est, max_splits=max_splits, max_cells=max_cells, 
                                                             min_size=min_size, verbosity=verbosity, breaks_per_dim=breaks_per_dim, X_range=X_range, lambda_1se=lambda_1se, 
                                                             split_check_fn=split_check_fn, obj_fn=obj_fn,
                                                             est_plan=est_plan, pr_cl=pr_cl, valid_fn=valid_fn, cv_bump_samples=cv_bump_samples, bump_ratio=bump_ratio)
    } else {
      lambda <- potential_lambdas[1]
    }
    partition_i <- which.min(is_obj_val_seq + lambda*complexity_seq)
  }


  if (do_bump) {
    if (verbosity > 0) cat("Grid > Bumping: Started.\n")

    if (bump_complexity$incl_comp_in_pick) { 
      best_val <- is_obj_val_seq[partition_i] + lambda*complexity_seq[partition_i]
    } else {
      best_val <- is_obj_val_seq[partition_i]
    }

    b_rets <- gen_bumped_partitions(bump_samples, bump_ratio, N, m_mode, verbosity, pr_cl, min_size=min_size*bump_ratio, 
                                    y=y, X_d=X, d=d, X_aux=X_aux, d_aux=d_aux, X_range=X_range, max_splits=max_splits, 
                                    max_cells=max_cells,  
                                    breaks_per_dim=breaks_per_dim, N_est=N_est, split_check_fn=split_check_fn, obj_fn=obj_fn, 
                                    min_size_aux=min_size, est_plan=est_plan, 
                                    valid_fn=valid_fn)
    bump_B <- length(b_rets)

    best_b <- NA
    # Index (within the winning bumped sequence) of the partition that was actually
    # evaluated; without CV it may have been clamped to a shorter sequence.
    best_partition_i <- NA
    for (b in seq_len(bump_B)) {
      b_ret <- b_rets[[b]]
      if (do_cv || bump_complexity$incl_comp_in_pick) b_complexity_seq <- vapply(b_ret$partition_seq, num_cells, numeric(1)) - 1
      if (do_cv) {
        partition_i_b <- which.min(b_ret$is_obj_val_seq + lambda*b_complexity_seq)
      } else {
        partition_i_b <- partition_i #default
        if (length(b_ret$partition_seq) < partition_i_b) {
          cat("Note: Couldn't build grid to desired granularity. Using most granular\n")
          partition_i_b <- length(b_ret$partition_seq)
        }
      }
      partition_b <- b_ret$partition_seq[[partition_i_b]]

      obj_ret <- obj_fn(y, X, d, N_est=N_est, partition=partition_b, est_plan=est_plan, sample="trtr")
      if (obj_ret[2] > 0 || obj_ret[3] > 0) next #N_cell_empty, N_cell_error
      if (bump_complexity$incl_comp_in_pick) {
        bump_val <- obj_ret[1] + lambda*b_complexity_seq[partition_i_b]
      } else {
        bump_val <- obj_ret[1]
      }

      if (bump_val < best_val) {
        best_val <- bump_val
        best_b <- b
        best_partition_i <- partition_i_b
      }
    }
    if (!is.na(best_b)) {
      if (verbosity > 0) {
        cat(paste("Grid > Bumping: Finished. Picking bumped partition."))
        cat(paste(" Old (unbumped) is_obj_val_seq=[", paste(is_obj_val_seq, collapse=" "), "]."))
        cat(paste(" Old (unbumped) complexity_seq=[", paste(complexity_seq, collapse=" "), "].\n"))
      }
      list[partition_seq, is_obj_val_seq_best_b, split_seq] = b_rets[[best_b]]
      # Use the same index that was evaluated above (equals the CV re-pick when do_cv).
      partition_i <- best_partition_i
      complexity_seq <- vapply(partition_seq, num_cells, numeric(1)) - 1
      is_obj_val_seq <- sapply(partition_seq, function(p){
        obj_fn(y, X, d, N_est=N_est, partition=p, est_plan=est_plan, sample="trtr")[1]
      })
    } else { 
      if (verbosity > 0) cat(paste("Grid > Bumping: Finished. No bumped partitions better than original.\n"))
    }
  }

  if (verbosity > 0) {
    #print(partition_seq)
    cat(paste("Grid: Finished. is_obj_val_seq=[", paste(is_obj_val_seq, collapse=" "), "]."))
    if (do_cv) {
      cat(paste(" complexity_seq=[", paste(complexity_seq, collapse=" "), "]."))
      cat(paste(" best partition=", paste(partition_i, collapse=" "), "."))
    }
    cat("\n")
  }
  partition <- partition_seq[[partition_i]]
  return(list(partition=partition, is_obj_val_seq=is_obj_val_seq, complexity_seq=complexity_seq, 
              partition_i=partition_i, partition_seq=partition_seq, split_seq=split_seq, lambda=lambda, 
              foldids=foldids))
}


#' Get break-points by looking at quantiles
#' 
#' Provides a set of potential split points for data according to quantiles (if possible).
#' For an integer column whose range spans at most \code{g}, every interior unique value is a
#' candidate. Otherwise the interior points of \code{g+2} evenly spaced quantiles are used, after de-duplication.
#' Factor columns get a dummy \code{c(0)} placeholder.
#'
#' @param X Features (for separate samples, only the first dataset is used)
#' @param X_range X-range
#' @param g # of quantiles
#' @param type Quantile type (see ?quantile and https://mathworld.wolfram.com/Quantile.html). 
#'             Types1-3 are discrete and this is good for passing to unique() when there are clumps
#'
#' @return list of potential breaks
get_quantile_breaks <- function(X, X_range, g=20, type=3) {
  if (is.null(g)) g <- 20 #fit_estimate has a different default that might get passed in.
  if (is_sep_sample(X)) X <- X[[1]]
  X <- ensure_good_X(X)

  lapply(seq_len(ncol(X)), function(k) {
    X_k <- X[, k]
    if (is.factor(X_k)) return(c(0)) #Dummy
    if (storage.mode(X_k) == "integer" && (X_range[[k]][2] - X_range[[k]][1]) <= g) {
      vals <- sort(unique(X_k))
      return(vals[-c(length(vals), 1)])
    }
    # If you want g cuts, then there are g+2 outer nodes; drop the min and max.
    qs <- unique(quantile(X_k, seq(0, 1, length.out=g+2), names=FALSE, type=type))
    qs[-c(length(qs), 1)]
  })
}

# Check whether a partition (given as cell factors) is usable.
# Returns list(fail=FALSE) on success, or list(fail=TRUE, <reason>=<value>) at the first failed check:
#  - min_size / min_size_aux: (only when min_size > 0) the smallest cell count is below min_size;
#  - always_d_var / always_d_var_aux: some column of d (d_aux) is constant within a cell,
#    as judged by const_vect().
# If d has no columns, no d check can fail.
valid_partition <- function(cell_factor, d=NULL, cell_factor_aux=NULL, d_aux=NULL, min_size=0) {
  # TRUE if d_mat (a vector) or any column of d_mat (a matrix) is constant within some cell.
  has_const_cell <- function(d_mat, cells) {
    if (is_vec(d_mat)) return(any(by(as.vector(d_mat), cells, FUN=const_vect)))
    for (m in seq_len(ncol(d_mat))) {
      if (any(by(d_mat[, m], cells, FUN=const_vect))) return(TRUE)
    }
    FALSE
  }

  #check none of the cells are too small
  if (min_size > 0) {
    if (length(cell_factor) == 0) return(list(fail=TRUE, min_size=0))
    lowest_size <- min(table(cell_factor))
    if (lowest_size < min_size) return(list(fail=TRUE, min_size=lowest_size))

    if (!is.null(cell_factor_aux)) {
      if (length(cell_factor_aux) == 0) return(list(fail=TRUE, min_size_aux=0))
      lowest_size_aux <- min(table(cell_factor_aux))
      if (lowest_size_aux < min_size) return(list(fail=TRUE, min_size_aux=lowest_size_aux))
    }
  }

  if (!is.null(d) && has_const_cell(d, cell_factor)) {
    return(list(fail=TRUE, always_d_var=FALSE))
  }
  # Aux treatment is grouped by the aux cells (the matrix branch previously used
  # cell_factor and reported the non-aux key).
  if (!is.null(d_aux) && has_const_cell(d_aux, cell_factor_aux)) {
    return(list(fail=TRUE, always_d_var_aux=FALSE))
  }
  list(fail=FALSE)
}


# Fit one partition sequence per bump (bootstrap) sample.
# bump_samples is expanded by expand_bump_samples(); each sample is fit by
# fit_partition_bump_b() via my_apply(). Extra arguments in ... are forwarded to each fit.
gen_bumped_partitions <- function(bump_samples, bump_ratio, N, m_mode, verbosity, pr_cl, allow_empty_aux=FALSE, allow_est_errors_aux=FALSE, ...) {
  assert_that(bump_ratio > 0, bump_ratio <= 1, msg="bump_ratio needs to be in (0,1]")
  bump_samples <- expand_bump_samples(bump_samples, bump_ratio, N, m_mode)
  bump_B <- length(bump_samples)

  # Forward the caller's allow_* settings (previously hard-coded to FALSE). Each inner fit
  # gets pr_cl=NULL; the cluster (if any) is handed to my_apply() for the outer loop.
  params <- c(list(samples=bump_samples, verbosity=verbosity-1, allow_empty_aux=allow_empty_aux,
                   allow_est_errors_aux=allow_est_errors_aux, pr_cl=NULL, m_mode=m_mode),
              list(...))

  my_apply(seq_len(bump_B), fit_partition_bump_b, verbosity==1 || !is.null(pr_cl), pr_cl, params)
}

# Determine the candidate split points for each dimension.
# If breaks_per_dim is NULL, candidates come from the unique sorted values of each continuous
# column of X (for separate samples, the first dataset): mid-points between consecutive
# unique values when mid_point=TRUE, otherwise all but the last unique value.
# If breaks_per_dim is given, a final break equal to the top of X_range is dropped (a split
# there would leave nothing in the "> split" side) and names are stripped.
# Categorical dimensions (in dim_cat) get a dummy c(0) placeholder when generated here.
get_usable_break_points <- function(breaks_per_dim, X, X_range, dim_cat, mid_point=TRUE) {
  if (is_sep_sample(X)) X <- X[[1]]
  K <- ncol(X)
  if (is.null(breaks_per_dim)) {
    return(lapply(seq_len(K), function(k) {
      if (k %in% dim_cat) return(c(0)) #Dummy just for place-holder
      u <- unique(sort(X[, k]))
      if (mid_point) u[-length(u)] + diff(u) / 2 else u[-length(u)] #skip last point
    }))
  }

  for (k in seq_len(K)) {
    if (!k %in% dim_cat) {
      n_k <- length(breaks_per_dim[[k]])
      # Guard n_k > 0: an empty break vector would give a zero-length if() condition.
      if (n_k > 0 && breaks_per_dim[[k]][n_k] == X_range[[k]][2]) {
        breaks_per_dim[[k]] <- breaks_per_dim[[k]][-n_k]
      }
    }
    breaks_per_dim[[k]] <- unname(breaks_per_dim[[k]]) #names messed up the get_desc_df() (though not in debugSource)
  }
  breaks_per_dim
}


# Penalty values at which the penalized choice over a partition sequence changes.
#
# Choosing which.min(is_obj_val_seq + lambda * complexity_seq) only changes at a finite set of
# lambdas. Starting from the first element, repeatedly jump to the later element with the steepest
# (most negative) slope from the current one; each such slope's magnitude is a tie lambda.
#
# is_obj_val_seq: objective values; typically trends downward.
# complexity_seq: matching complexities (e.g. number of cells - 1); assumed increasing.
# Returns the tie lambdas (decreasing: the slopes go from strongly negative towards 0). A lambda
# slightly bigger than a tie picks the earlier index, slightly smaller picks the later one.
# Returns c() when the first element is already the minimum for every lambda >= 0.
get_lambda_ties <- function(is_obj_val_seq, complexity_seq) {
  n_seq = length(is_obj_val_seq)
  slopes = c() #goes from strongly negative upwards; stop before reaching 0
  hull_i = 1
  while(hull_i < n_seq) {
    later = (hull_i+1):n_seq
    i_slopes = rep(NA, n_seq)
    i_slopes[later] = (is_obj_val_seq[later] - is_obj_val_seq[hull_i])/(complexity_seq[later] - complexity_seq[hull_i])
    best_slope = min(i_slopes, na.rm=TRUE)
    if(best_slope>=0) break
    slopes = c(slopes, best_slope)
    hull_i = which.min(i_slopes)
  }
  # Even a single hull segment defines one tie; only an empty hull means "no ties".
  if(length(slopes)==0) return(c())
  return(abs(slopes))
}

# Enumerate the distinct two-way splits of one categorical window.
# Each returned element is the set of levels placed on one side; its complement is the other side.
# Only subsets of size <= length/2 are produced. For the exact-half size only the first half of
# combn()'s output is kept, so each split is listed once rather than twice.
gen_cat_window_splits <- function(chr_vec) {
  n_levels <- length(chr_vec)
  by_size <- lapply(seq_len(floor(n_levels / 2)), function(size) {
    subsets <- combn(chr_vec, size, simplify = FALSE)
    if (size == n_levels / 2) {
      subsets <- subsets[seq_len(length(subsets) / 2)]
    }
    subsets
  })
  Reduce(c, by_size, list())
}

# Count of distinct two-way splits of a categorical window with window_len levels
# (the same count that gen_cat_window_splits() enumerates): sum of choose(window_len, m) over
# m = 1..floor(window_len/2), halving the exact-half term.
n_cat_window_splits <- function(window_len) {
  sizes <- seq_len(floor(window_len / 2))
  counts <- choose(window_len, sizes)
  is_half <- sizes == window_len / 2
  counts[is_half] <- counts[is_half] / 2
  sum(counts)
}

# Total number of candidate splits along a categorical dimension: for every current window
# (get_windows_cat()), add the number of two-way splits of a window of that size.
n_cat_splits <- function(s_by_dim_k, X_range_k) {
  window_sizes <- lengths(get_windows_cat(s_by_dim_k, X_range_k))
  sum(vapply(window_sizes, n_cat_window_splits, numeric(1)))
}

# Current windows (groups of levels) of a categorical dimension: every group already split off
# (the elements of s_by_dim_k), followed by the levels of X_k_range not yet in any group.
get_windows_cat <- function(s_by_dim_k, X_k_range) {
  assigned <- unlist(s_by_dim_k)
  remaining <- X_k_range[!(X_k_range %in% assigned)]
  c(s_by_dim_k, list(remaining))
}

# Windows of a continuous dimension delimited by its cut points: a list of c(lower, upper)
# pairs running from X_k_range[1], through each cut in s_by_dim_k, up to X_k_range[2].
get_window_cont <- function(s_by_dim_k, X_k_range) {
  n_windows <- length(s_by_dim_k) + 1
  lapply(seq_len(n_windows), function(w) {
    lower <- if (w > 1) s_by_dim_k[w - 1] else X_k_range[1]
    upper <- if (w < n_windows) s_by_dim_k[w] else X_k_range[2]
    c(lower, upper)
  })
}

# Interaction of every per-dimension factor except dimension k.
# With a single dimension there is nothing left to interact, so every observation gets the same
# placeholder level "|".
gen_holdout_interaction <- function(factors_by_dim, k) {
  if (length(factors_by_dim) <= 1) {
    n_obs <- length(factors_by_dim[[1]])
    return(factor(rep("|", n_obs)))
  }
  interaction(factors_by_dim[-k])
}

# Number of candidate splits along dimension k. Continuous dims count the entries of
# breaks_per_dim[[k]]; categorical dims (k in partition$dim_cat) count the two-way splits of
# every current window.
n_breaks_k <- function(breaks_per_dim, k, partition, X_range) {
  is_categorical <- k %in% partition$dim_cat
  if (!is_categorical) {
    return(length(breaks_per_dim[[k]]))
  }
  n_cat_splits(partition$s_by_dim[[k]], X_range[[k]])
}


# Cheap pre-check of a continuous cut using only the "shifted" slice of observations between the
# previously checked cut and this one.
# Returns FALSE when bucket_min_n is set and min(shifted_N) is below it, or when bucket_min_d_var
# is TRUE, shifted_d is given and any_const_m() flags it (presumably a cell with constant d —
# confirm against any_const_m). Returns TRUE otherwise.
rolling_split_check <- function(shifted_N, shifted_d=NULL, shifted_cell_factor_nk, m_mode, bucket_min_n=NA, bucket_min_d_var=FALSE) {
  if (!is.na(bucket_min_n) && min(shifted_N) < bucket_min_n) {
    return(FALSE)
  }
  check_d_var <- bucket_min_d_var && !is.null(shifted_d)
  if (check_d_var && any_const_m(shifted_d, shifted_cell_factor_nk, m_mode)) {
    return(FALSE)
  }
  TRUE
}

# Try every candidate split along dimension k and return the best one.
#
# Continuous dimension: each entry of breaks_per_dim[[k]] is tried as a cut, in order.
# Categorical dimension: each two-way split of each current window (get_windows_cat()) is tried.
# Candidates already marked FALSE in valid_breaks[[k]] are skipped. Otherwise a candidate must pass,
# in order:
#  - (continuous only, when split_check_fn is given) split_check_fn() on the observations between
#    the previously checked cut and this one;
#  - valid_fn (through valid_partition_m()) on the main-sample cells inside the window being
#    split, using d and min_size;
#  - when !allow_empty_aux and X_aux is given, the same check on the aux-sample window, using
#    d_aux (unless allow_est_errors_aux) and min_size_aux;
#  - obj_fn() on the full tentative partition ("trtr"); obj_ret[3]>0 rejects the candidate.
# Every rejected candidate is marked FALSE in the returned valid_breaks_k.
#
# Returns list(search_ret, valid_breaks_k). search_ret is list() if no candidate was accepted,
# otherwise list(val, new_split, new_factors_by_dim, new_factors_by_dim_aux) for the accepted
# candidate with the lowest objective value.
fit_partition_full_k <- function(k, y, X_d, d, X_range, pb, debug, valid_breaks, factors_by_dim, X_aux, 
                                 factors_by_dim_aux, partition, verbosity, allow_empty_aux=TRUE, d_aux, 
                                 allow_est_errors_aux, min_size, min_size_aux, obj_fn, N_est, est_plan, 
                                 split_check_fn = NULL, breaks_per_dim, valid_fn=NULL) { #, n_cut
  assert_that(is.flag(allow_empty_aux), msg="allow_empty_aux needs to be logical flags.")
  list[M, m_mode, N, K] = get_sample_type(y, X_d, d, checks=FALSE)
  if(is.null(valid_fn)) valid_fn = valid_partition
  search_ret = list()
  best_new_val = Inf
  valid_breaks_k = valid_breaks[[k]]
  # Cell factor over every dimension except k; each candidate split on k is crossed with it.
  cell_factor_nk = gen_holdout_interaction_m(factors_by_dim, k, is_sep_sample(X_d))
  if(!is.null(X_aux)) {
    cell_factor_nk_aux = gen_holdout_interaction_m(factors_by_dim_aux, k, is_sep_sample(X_aux))
  }
  
  if(!is_factor_dim_k_m(X_d, k, m_mode==DS.MULTI_SAMPLE)) {
    # ---- Continuous dimension ----
    n_pot_break_points_k = length(breaks_per_dim[[k]])
    # NOTE(review): vals is never used below.
    vals = rep(NA, n_pot_break_points_k)
    # Start of the "shifted" slice used by split_check_fn.
    prev_split_checked = X_range[[k]][1]
    # Window = segment of dimension k between the existing splits that bracket the current cut
    # (initially from below the minimum up to the first existing split, or the maximum).
    # The validity checks below only look at observations selected by these window masks.
    win_LB = X_range[[k]][1]-1
    win_UB = if(length(partition$s_by_dim[[k]])>0) partition$s_by_dim[[k]][1] else X_range[[k]][2]
    win_mask = gen_cont_window_mask_m(X_d, k, win_LB, win_UB)
    win_mask_aux = gen_cont_window_mask_m(X_aux, k, win_LB, win_UB)
    for(X_k_cut_i in seq_len(n_pot_break_points_k)) { #cut-point is top end of segment, 
      if (verbosity>0 && !is.null(pb)) setTxtProgressBar(pb, getTxtProgressBar(pb)+1)
      X_k_cut = breaks_per_dim[[k]][X_k_cut_i]
      if(X_k_cut %in% partition$s_by_dim[[k]]) {
        # Cut is already a split: move the window to the next segment and skip this candidate.
        prev_split_checked = X_k_cut
        win_LB = X_k_cut
        higher_prev_split = partition$s_by_dim[[k]][partition$s_by_dim[[k]]>X_k_cut]
        win_UB = if(length(higher_prev_split)>0) min(higher_prev_split) else X_range[[k]][2]
        win_mask = gen_cont_window_mask_m(X_d, k, win_LB, win_UB)
        win_mask_aux = gen_cont_window_mask_m(X_aux, k, win_LB, win_UB)
        next
      } 
      if(!valid_breaks_k[[1]][X_k_cut_i]) next
      new_split = partition_split(k, X_k_cut)
      tent_partition = add_partition_split(partition, new_split)
      
      # Factors for the full tentative partition (used by obj_fn below).
      tent_split_fac_k = get_factors_from_splits_dim_m(X_d, X_range[[k]], tent_partition$s_by_dim[[k]], k)
      tent_cell_factor = interaction2_m(cell_factor_nk, tent_split_fac_k, m_mode==DS.MULTI_SAMPLE)
      if(!is.null(X_aux)) {
        tent_split_fac_k_aux = get_factors_from_splits_dim_m(X_aux, X_range[[k]], tent_partition$s_by_dim[[k]], k)
        # NOTE(review): tent_cell_factor_aux is only used by the commented-out global approach.
        tent_cell_factor_aux = interaction2_m(cell_factor_nk_aux, tent_split_fac_k_aux, is_sep_sample(X_aux))
      }
      
      if(!is.null(split_check_fn)){
        # Rolling check on the slice (prev_split_checked, X_k_cut]; prev_split_checked is not
        # advanced on failure, so the next candidate's slice also includes these observations.
        shifted_mask = gen_cont_window_mask_m(X_d, k, prev_split_checked, X_k_cut)
        shifted_N = sum_m(shifted_mask, m_mode==DS.MULTI_SAMPLE)
        shifted_cell_factor_nk = droplevels_m(apply_mask_m(cell_factor_nk, shifted_mask, m_mode==DS.MULTI_SAMPLE), m_mode==DS.MULTI_SAMPLE)
        shifted_d = if(is.null(d)) NULL else apply_mask_m(d, shifted_mask, m_mode==DS.MULTI_SAMPLE)
        split_OK = split_check_fn(shifted_N, shifted_d, shifted_cell_factor_nk, m_mode)
        if(!split_OK) {
          valid_breaks_k[[1]][X_k_cut_i] = FALSE
          next
        }
      }
      
      # do_window_approach
      #The bucket checks don't help much. 
      #- Though I do check for non-zero var of D, that's just on the left so to check on right side too
      #- Note that though not min_size as different than bucket_min_n)
      win_split_cond = gen_cont_win_split_cond_m(X_d, win_mask, k, X_k_cut)
      win_cell_factor_nk = apply_mask_m(cell_factor_nk, win_mask, m_mode==DS.MULTI_SAMPLE)
      win_cell_factor = interaction2_m(win_cell_factor_nk, win_split_cond, m_mode==DS.MULTI_SAMPLE)
      win_d = apply_mask_m(d, win_mask, m_mode==DS.MULTI_SAMPLE)
      valid_ret = valid_partition_m(m_mode==DS.MULTI_SAMPLE, valid_fn, win_cell_factor, d=win_d, min_size=min_size)
      if(!valid_ret$fail) {
        if(!allow_empty_aux && !is.null(X_aux)) {
          # NOTE(review): unlike the categorical branch, win_split_cond_aux is not wrapped in
          # factor(levels=c(FALSE, TRUE)) here; confirm gen_cont_win_split_cond_m keeps both levels.
          win_split_cond_aux = gen_cont_win_split_cond_m(X_aux, win_mask_aux, k, X_k_cut)
          win_cell_factor_aux = interaction2_m(apply_mask_m(cell_factor_nk_aux, win_mask_aux, is_sep_sample(X_aux)), 
                                               win_split_cond_aux, is_sep_sample(X_aux), drop=allow_empty_aux)
          win_d_aux = if(!allow_est_errors_aux) apply_mask_m(d_aux, win_mask_aux, is_sep_sample(X_aux)) else NULL
          valid_ret = valid_partition_m(is_sep_sample(X_aux), valid_fn, win_cell_factor_aux, d=win_d_aux, min_size=min_size_aux)
        }
      }
      # Global approach
      # valid_ret = valid_fn(tent_cell_factor, d=d, min_size=min_size)
      # if(!valid_ret$fail) {
      #   valid_ret = valid_fn(tent_cell_factor_aux, d=d_aux, min_size=2)
      # }
      if(valid_ret$fail) {
        #cat("Invalid partition\n")
        valid_breaks_k[[1]][X_k_cut_i] = FALSE
        next
      }
      if(debug) cat(paste("k", k, ". X_k", X_k_cut, "\n"))
      obj_ret = obj_fn(y, X_d, d, N_est=N_est, cell_factor_tr = tent_cell_factor, debug=debug, est_plan=est_plan, 
                       sample="trtr")
      if(obj_ret[3]>0) { #don't need to check [2] (empty cells) as we already did that
        #cat("Estimation errors\n")
        valid_breaks_k[[1]][X_k_cut_i] = FALSE
        next
      }
      val = obj_ret[1]
      stopifnot(is.finite(val))
      prev_split_checked = X_k_cut
      
      if(val<best_new_val) {
        #if(verbosity>0) print(paste("Testing split at ", X_k_cut, ". Val=", split_res$val))
        best_new_val = val
        new_factors_by_dim = replace_k_factor_m(factors_by_dim, k, tent_split_fac_k, is_sep_sample(X_d))
        if(!is.null(X_aux)) {
          new_factors_by_dim_aux = replace_k_factor_m(factors_by_dim_aux, k, tent_split_fac_k_aux, is_sep_sample(X_aux))
        }
        else new_factors_by_dim_aux = NULL
        search_ret = list(val=val, new_split=new_split, new_factors_by_dim=new_factors_by_dim, 
                          new_factors_by_dim_aux=new_factors_by_dim_aux)
      }
    }
    
  }
  else { #categorical variable
    # ---- Categorical dimension: one set of candidates per current window ----
    windows = get_windows_cat(partition$s_by_dim[[k]], X_range[[k]])
    for(window_i in seq_len(length(windows))) {
      window = windows[[window_i]]
      win_mask = gen_cat_window_mask_m(X_d, k, window)
      win_mask_aux = gen_cat_window_mask_m(X_aux, k, window)
      pot_splits = gen_cat_window_splits(window)
      for(win_split_i in seq_len(length(pot_splits))) {
        win_split_val = pot_splits[[win_split_i]]
        #TODO: Refactor with continuous case
        if (verbosity>0 && !is.null(pb)) setTxtProgressBar(pb, getTxtProgressBar(pb)+1)
        # valid_breaks_k holds one logical vector per window.
        if(!valid_breaks_k[[window_i]][win_split_i]) next
        
        new_split = partition_split(k, win_split_val)
        tent_partition = add_partition_split(partition, new_split)
        
        tent_split_fac_k = get_factors_from_splits_dim_m(X_d, X_range[[k]], tent_partition$s_by_dim[[k]], k)
        tent_cell_factor = interaction2_m(cell_factor_nk, tent_split_fac_k, m_mode==DS.MULTI_SAMPLE)
        if(!is.null(X_aux)) {
          tent_split_fac_k_aux = get_factors_from_splits_dim_m(X_aux, X_range[[k]], tent_partition$s_by_dim[[k]], k)
          # NOTE(review): tent_cell_factor_aux is not used below.
          tent_cell_factor_aux = interaction2_m(cell_factor_nk_aux, tent_split_fac_k_aux, is_sep_sample(X_aux))
        }
        
        # do_window_approach
        win_split_cond = gen_cat_win_split_cond_m(X_d, win_mask, k, win_split_val)
        win_cell_factor = interaction2_m(apply_mask_m(cell_factor_nk, win_mask, m_mode==DS.MULTI_SAMPLE), win_split_cond, m_mode==DS.MULTI_SAMPLE)
        win_d = apply_mask_m(d, win_mask, m_mode==DS.MULTI_SAMPLE)
        valid_ret = valid_partition_m(m_mode==DS.MULTI_SAMPLE, valid_fn, win_cell_factor, d=win_d, min_size=min_size)
        if(!valid_ret$fail) {
          if(!is.null(X_aux) && !allow_empty_aux) {
            win_split_cond_aux = factor(gen_cat_win_split_cond_m(X_aux, win_mask_aux, k, win_split_val), levels=c(FALSE, TRUE))
            win_cell_factor_aux = interaction2_m(apply_mask_m(cell_factor_nk_aux, win_mask_aux, is_sep_sample(X_aux)), 
                                                 win_split_cond_aux, is_sep_sample(X_aux), drop=allow_empty_aux)
            win_d_aux = if(!allow_est_errors_aux) apply_mask_m(d_aux, win_mask_aux, is_sep_sample(X_aux)) else NULL
            valid_ret = valid_partition_m(is_sep_sample(X_aux), valid_fn, win_cell_factor_aux, d=win_d_aux, min_size=min_size_aux)
          }
        }
        if(valid_ret$fail) {
          #cat("Invalid partition\n")
          valid_breaks_k[[window_i]][win_split_i] = FALSE
          next
        }
        if(debug) cat(paste("k", k, ". X_k", win_split_val, "\n"))
        obj_ret = obj_fn(y, X_d, d, N_est=N_est, cell_factor_tr = tent_cell_factor, debug=debug, est_plan=est_plan, 
                         sample="trtr")
        if(obj_ret[3]>0) { #don't need to check [2] (empty cells) as we already did that
          #cat("Estimation errors\n")
          valid_breaks_k[[window_i]][win_split_i] = FALSE
          next
        }
        val = obj_ret[1]
        stopifnot(is.finite(val))
        
        
        if(val<best_new_val) {
          #if(verbosity>0) print(paste("Testing split at ", X_k_cut, ". Val=", split_res$val))
          best_new_val = val
          new_factors_by_dim = replace_k_factor_m(factors_by_dim, k, tent_split_fac_k, is_sep_sample(X_d))
          if(!is.null(X_aux)) {
            new_factors_by_dim_aux = replace_k_factor_m(factors_by_dim_aux, k, tent_split_fac_k_aux, is_sep_sample(X_aux))
          }
          else new_factors_by_dim_aux = NULL
          search_ret = list(val=val, new_split=new_split, new_factors_by_dim=new_factors_by_dim, 
                            new_factors_by_dim_aux=new_factors_by_dim_aux)
        }
        
      }
    }
  }

  return(list(search_ret, valid_breaks_k))
}

# Private helper: fresh valid-break flags for categorical dimension k, one all-TRUE logical vector
# per current window (fit_partition_full_k() indexes valid_breaks_k[[window_i]][win_split_i]).
init_cat_valid_breaks_k <- function(s_by_dim_k, X_range_k) {
  windows = get_windows_cat(s_by_dim_k, X_range_k)
  lapply(windows, function(window) rep(TRUE, n_cat_window_splits(length(window))))
}

# Greedy grid search: repeatedly add the single split (over all dimensions) that most lowers the
# in-sample objective, recording the partition after each split.
#
# There are three general problems with a partition:
# 1) Empty cells
# 2) Non-empty cells where the objective can't be calculated
# 3) Cells where it can be calculated but which are too small to be wanted
# Main sample: we assume that a valid partition removes #1 and #2. Use min_size for #3.
# For Aux: use allow_empty_aux (passed through ...), allow_est_errors_aux and min_size_aux.
# Include d_aux if you want to make sure that non-empty cells in aux have positive variance in d.
# For CV: allow_empty_aux=TRUE, allow_est_errors_aux=FALSE, min_size_aux=1 (a weaker check than
# removing estimation errors).
# Set nsplits_k_warn_limit=NA (or Inf) to disable the "many splits" warning.
#
# Returns list(partition_seq, is_obj_val_seq, split_seq): the partition after 0, 1, 2, ... splits,
# the matching in-sample objective values, and the split added at each step.
fit_partition_full <- function(y, X, d=NULL, X_aux=NULL, d_aux=NULL, X_range, max_splits=Inf, max_cells=Inf, 
                               min_size=2, verbosity=0, breaks_per_dim, N_est, obj_fn, allow_est_errors_aux=TRUE, 
                               min_size_aux=2, est_plan, partition=NULL, nsplits_k_warn_limit=200, pr_cl=NULL, 
                               ...) {
  assert_that(max_splits>=0, max_cells>=1, min_size>=1, msg="Need max_splits>=0, max_cells>=1, min_size>=1.")
  assert_that(is.flag(allow_est_errors_aux), msg="allow_est_errors_aux needs to be a flag")
  assert_that(is.na(nsplits_k_warn_limit) || nsplits_k_warn_limit>=1, msg="nsplits_k_warn_limit not understood")
  list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE)
  est_min = ifelse(is.null(d), 2, 3) #If don't always need variance calc: ifelse(is.null(d), ifelse(honest, 2, 1), ifelse(honest, 3, 2))
  min_size = max(min_size, est_min)
  if(!allow_est_errors_aux)  min_size_aux = max(min_size_aux, est_min)
  debug = FALSE
  if(is.null(partition)) partition = grid_partition(X_range, colnames(X))
  breaks_per_dim = get_usable_break_points(breaks_per_dim, X, X_range, partition$dim_cat)
  # Per dimension: a list of logical vectors (a single vector for continuous dims, one per
  # window for categorical dims) marking which candidate splits are still worth trying.
  valid_breaks = vector("list", length=K)
  
  for(k in seq_len(K)) {
    n_split_breaks_k = n_breaks_k(breaks_per_dim, k, partition, X_range)
    if(k %in% partition$dim_cat) {
      # A supplied partition may already have categorical splits (several windows), so build the
      # flags per window rather than as one vector.
      valid_breaks[[k]] = init_cat_valid_breaks_k(partition$s_by_dim[[k]], X_range[[k]])
    }
    else {
      valid_breaks[[k]] = list(rep(TRUE, n_split_breaks_k))
    }
    if(!is.na(nsplits_k_warn_limit) && n_split_breaks_k>nsplits_k_warn_limit) warning(paste("Warning: Many splits (", n_split_breaks_k, ") along dimension", k, "\n"))
  }
  factors_by_dim = get_factors_from_partition(partition, X)
  # Always defined (NULL without aux data) because it is passed to fit_partition_full_k() below.
  factors_by_dim_aux = NULL
  if(!is.null(X_aux)) {
    factors_by_dim_aux = get_factors_from_partition(partition, X_aux)
  }
  
  if(verbosity>0){
    cat("Grid > Fitting: Started.\n")
    t0 = Sys.time()
  } 
  split_i = 1
  seq_val = c()
  obj_ret = obj_fn(y, X, d, N_est=N_est, cell_factor_tr = interaction_m(factors_by_dim, is_sep_sample(X)), est_plan=est_plan, sample="trtr")
  if(obj_ret[3]>0 || !is.finite(obj_ret[1])) {
    stop("Estimation error with initial partition")
  }
  seq_val[1] = obj_ret[1]
  partition_seq = list()
  split_seq = list()
  partition_seq[[1]] = partition
  style = if(summary(stdout())$class=="terminal") 3 else 1
  if(!is.null(pr_cl) && !requireNamespace("parallel", quietly = TRUE)) {
    stop("Package \"parallel\" needed for this function to work. Please install it.", call. = FALSE)
  }
  
  while(TRUE) {
    if(split_i>max_splits) break
    if(num_cells(partition)==max_cells) break
    n_cuts_k = rep(0, K)
    for(k in seq_len(K)) {
      n_cuts_k[k] = n_breaks_k(breaks_per_dim, k, partition, X_range)
    }
    n_cuts_total = sum(n_cuts_k)
    if(n_cuts_total==0) break
    if(verbosity>0) {
      cat(paste("Grid > Fitting > split ", split_i, ": Started\n"))
      t1 = Sys.time()
      if(is.null(pr_cl)) pb = txtProgressBar(0, n_cuts_total, style = style)
    }
    
    params = c(list(y=y, X_d=X, d=d, X_range=X_range, pb=NULL, debug=debug, valid_breaks=valid_breaks, 
                  factors_by_dim=factors_by_dim, X_aux=X_aux, factors_by_dim_aux=factors_by_dim_aux, partition=partition, 
                  verbosity=verbosity, d_aux=d_aux, allow_est_errors_aux=allow_est_errors_aux, 
                  min_size=min_size, min_size_aux=min_size_aux, obj_fn=obj_fn, N_est=N_est, est_plan=est_plan, 
                  breaks_per_dim=breaks_per_dim), list(...))
    
    col_rets = my_apply(1:K, fit_partition_full_k, verbosity, pr_cl, params)
    
    # Pick the best candidate across all dimensions and store each dimension's updated flags.
    best_new_val = Inf
    best_new_split = NULL
    best_new_factors_by_dim_aux = NULL
    for(k in seq_len(K)) {
      col_ret = col_rets[[k]]
      search_ret = col_ret[[1]]
      valid_breaks[[k]] = col_ret[[2]]
      if(length(search_ret)>0 && search_ret$val<best_new_val) {
        best_new_val = search_ret$val
        best_new_split = search_ret$new_split
        best_new_factors_by_dim = search_ret$new_factors_by_dim
        best_new_factors_by_dim_aux = search_ret$new_factors_by_dim_aux # NULL without aux data
      }
    }
    
    if (verbosity>0) {
      t2 = Sys.time() #can use as.numeric(t1) to convert to seconds
      td = t2-t1
      if(is.null(pr_cl)) close(pb)
    }
    if(is.null(best_new_split)) {
      if (verbosity>0) cat(paste("Grid > Fitting > split ", split_i, ": Finished. Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ". No valid splits\n"))
      break
    } 
    best_new_partition = add_partition_split(partition, best_new_split)
    if(num_cells(best_new_partition)>max_cells) {
      if (verbosity>0) cat(paste("Grid > Fitting > split ", split_i, ": Finished. Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ". Best split has results in too many cells\n"))
      break
    } 
    partition = best_new_partition
    factors_by_dim = best_new_factors_by_dim
    factors_by_dim_aux = best_new_factors_by_dim_aux
    split_i = split_i + 1
    seq_val[split_i] = best_new_val
    partition_seq[[split_i]] = partition
    split_seq[[split_i-1]] = best_new_split
    k_new = best_new_split[[1]]
    if(k_new %in% partition$dim_cat) {
      # A categorical split changes that dimension's windows, so re-create its per-window flags.
      valid_breaks[[k_new]] = init_cat_valid_breaks_k(partition$s_by_dim[[k_new]], X_range[[k_new]])
    }
    if (verbosity>0) { 
      # split_i was already advanced above; report the number of the split that just finished.
      cat(paste("Grid > Fitting > split ", split_i-1, ": Finished.",
                " Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ".",
                " New split: k=", best_new_split[[1]], ", cut=", best_new_split[[2]], ", val=", best_new_val, "\n"))
    }
  }
  if (verbosity>0) {
    tn = Sys.time()
    td = tn-t0
    cat("Grid > Fitting: Finished.")
    cat(paste(" Entire Search Duration: ", format(as.numeric(td)), " ", attr(td, "units"), "\n"))
  }
  return(list(partition_seq=partition_seq, is_obj_val_seq=seq_val, split_seq=split_seq))
}

# Combine a sample with aux data. Accepts two lists of datasets or two single datasets.
# M_mult TRUE (lists of datasets): concatenate with c(). Otherwise wrap as list(X, X_aux).
# When X_aux is NULL, X is returned unchanged.
add_samples <- function (X, X_aux, M_mult) {
  if (!is.null(X_aux)) {
    combined <- if (M_mult) c(X, X_aux) else list(X, X_aux)
    return(combined)
  }
  X
}


# Fit the greedy partition sequence on bump sample b.
# The subsample samples[[b]] (via subsample_m()) is the training data. The full X_d/d, combined
# with any X_aux/d_aux through add_samples(), is the aux data. m_mode and nsplits_k_warn_limit are
# captured so they are not forwarded through ...; the many-splits warning is always disabled here.
fit_partition_bump_b <- function(b, samples, y, X_d, d=NULL, m_mode, X_aux, d_aux, verbosity, nsplits_k_warn_limit=NA, ...){
  if (verbosity > 0) cat(paste("Grid > Bumping > b = ", b, "\n"))
  list[y_b, X_b, d_b] = subsample_m(y, X_d, d, samples[[b]])
  multi_sample <- is_sep_sample(X_d)
  aux_X <- add_samples(X_d, X_aux, multi_sample)
  aux_d <- add_samples(d, d_aux, multi_sample)
  fit_partition_full(y = y_b, X = X_b, d = d_b, X_aux = aux_X, d_aux = aux_d,
                     verbosity = verbosity, nsplits_k_warn_limit = NA, ...)
}

# These are bump wrappers
# Choose the partition minimizing is_obj_val_seq + lambda * (num_cells - 1).
# obj is a single fit (list with is_obj_val_seq and partition_seq) or, when is_bumped, a list of
# such fits whose sequences are pooled in order.
# Returns list(index into the (pooled) sequence, chosen partition).
get_part_for_lambda <- function(obj, lambda, is_bumped=FALSE) {
  n_complexity <- function(fit) sapply(fit$partition_seq, num_cells) - 1
  if (!is_bumped) {
    best_i <- which.min(obj$is_obj_val_seq + lambda * n_complexity(obj))
    return(list(best_i, obj$partition_seq[[best_i]]))
  }
  pooled_vals <- unlist(lapply(obj, function(fit) fit$is_obj_val_seq))
  pooled_complexity <- unlist(lapply(obj, n_complexity))
  pooled_parts <- unlist(lapply(obj, function(fit) fit$partition_seq), recursive = FALSE)
  best_i <- which.min(pooled_vals + lambda * pooled_complexity)
  list(best_i, pooled_parts[[best_i]])
}
# Number of candidate partitions in a CV-training fit: the length of its is_obj_val_seq, or
# (when is_bumped) the summed lengths over the list of fits.
get_num_parts <- function(cvtr_fit, is_bumped=FALSE) {
  if (!is_bumped) {
    return(length(cvtr_fit$is_obj_val_seq))
  }
  seq_lengths <- sapply(cvtr_fit, function(fit) length(fit$is_obj_val_seq))
  sum(seq_lengths)
}

# Tie lambdas (see get_lambda_ties()) for one fit, or concatenated over a list of fits when
# is_bumped. Complexity is num_cells - 1 for each partition in the sequence.
get_all_lambda_ties <- function(cvtr_fit, is_bumped=FALSE) {
  ties_for_fit <- function(fit) {
    complexity <- sapply(fit$partition_seq, num_cells) - 1
    get_lambda_ties(fit$is_obj_val_seq, complexity)
  }
  if (is_bumped) {
    return(unlist(lapply(cvtr_fit, ties_for_fit)))
  }
  ties_for_fit(cvtr_fit)
}

# ... params sent to fit_partition_full()
# CV work for fold f.
# Fits the greedy partition sequence on the fold's training part, using its CV part as aux data.
# When cv_bump_samples requests it, bumped sequences are also fitted; they are then optionally
# re-scored on the full training part (recal_is_obj_b).
# If potential_lambdas is NULL, returns the fit (when bumping, a list whose first element is the
# unbumped fit) so the caller can derive lambdas. Otherwise returns the out-of-sample objective for
# each potential lambda (eval_lambdas()).
# nsplits_k_warn_limit, min_size_aux, allow_empty_aux and allow_est_errors_aux are named here so
# they are not forwarded twice through ...
cv_pick_lambda_f <- function(f, y, X_d, d, folds_ret, nfolds, potential_lambdas, N_est, 
                             verbosity, obj_fn, cv_tr_min_size, est_plan, cv_bump_samples, bump_ratio, 
                             nsplits_k_warn_limit=NA, min_size_aux=1, allow_empty_aux=TRUE, allow_est_errors_aux=FALSE, recal_is_obj_b=TRUE, ...) { #catch some of the params that might still be in ...
  if(verbosity>0) cat(paste("Grid > CV > Fold", f, "\n"))
  supplied_lambda = !is.null(potential_lambdas)
  
  list[y_f_tr, y_f_cv, X_f_tr, X_f_cv, d_f_tr, d_f_cv] = split_sample_folds_m(y, X_d, d, folds_ret, f)

  do_bump = (length(cv_bump_samples)>1 || cv_bump_samples>0)
  cvtr_fit = fit_partition_full(y_f_tr, X_f_tr, d_f_tr, X_f_cv, d_f_cv, 
                                min_size=cv_tr_min_size, verbosity=verbosity, 
                                N_est=N_est, obj_fn=obj_fn, allow_empty_aux=TRUE, 
                                allow_est_errors_aux=FALSE, min_size_aux=1, est_plan=est_plan, 
                                nsplits_k_warn_limit=NA, ...) #min_size_aux is weaker than removing est errors
  
  if(do_bump) {
    if(length(cv_bump_samples)>1) cv_bump_samples = cv_bump_samples[[f]]
    list[M, m_mode, N_tr, K] = get_sample_type(y_f_tr, X_f_tr, d_f_tr, checks=FALSE)
    cvtr_fit_bumps = gen_bumped_partitions(bump_samples=cv_bump_samples, bump_ratio, N_tr, m_mode, verbosity=verbosity, pr_cl=NULL, 
                                     min_size=cv_tr_min_size*bump_ratio, 
                                     y=y_f_tr, X_d=X_f_tr, d=d_f_tr, X_aux=X_f_cv, d_aux=d_f_cv, 
                                     N_est=N_est, obj_fn=obj_fn, 
                                     min_size_aux=1, est_plan=est_plan, nsplits_k_warn_limit=NA,
                                     allow_empty_aux=allow_empty_aux, allow_est_errors_aux=allow_est_errors_aux,
                                     ...)
    if(recal_is_obj_b) {
      # Re-score each bumped partition on the full (unbumped) training fold. Pass the same N_est
      # that fit_partition_full() used for the unbumped fit so the sequences are comparable.
      for(b in seq_along(cvtr_fit_bumps)) {
        cvtr_fit_bumps[[b]]$is_obj_val_seq = sapply(cvtr_fit_bumps[[b]]$partition_seq, function(p){
          obj_fn(y_f_tr, X_f_tr, d_f_tr, N_est=N_est, partition=p, est_plan=est_plan, sample="trtr")[1]
        })
      }
    }
    cvtr_fit = c(list(cvtr_fit), cvtr_fit_bumps)
  }
  
  if(!supplied_lambda) {
    return(cvtr_fit)
  } 
  #If we know the lambdas, eval data while we have it
  return(eval_lambdas(obj_fn, est_plan, potential_lambdas, cvtr_fit, y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est, do_bump))
}


# Out-of-sample objective for each lambda in potential_lambdas.
# For each lambda, the partition chosen by get_part_for_lambda() from the CV-training fit(s) is
# scored with obj_fn on the training/CV split (sample="trcv"). Scores are cached by partition index,
# so each distinct partition is evaluated once. Returns a vector parallel to potential_lambdas.
eval_lambdas <- function(obj_fn, est_plan, potential_lambdas, cvtr_fit, y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est, is_bumped) {
  oos_by_partition <- rep(NA, get_num_parts(cvtr_fit, is_bumped))
  oos_by_lambda <- rep(NA, length(potential_lambdas))
  debug <- FALSE
  for (lambda_i in seq_along(potential_lambdas)) {
    choice <- get_part_for_lambda(cvtr_fit, potential_lambdas[lambda_i], is_bumped)
    part_i <- choice[[1]]
    if (is.na(oos_by_partition[part_i])) {
      part <- choice[[2]]
      if (debug) cat(paste("s_by_dim", paste(part$s_by_dim, collapse=" "), "\n"))
      scored <- obj_fn(y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est=N_est, partition=part, debug=debug, 
                       est_plan=est_plan, sample="trcv")
      oos_by_partition[part_i] <- scored[1]
    }
    oos_by_lambda[lambda_i] <- oos_by_partition[part_i]
  }
  oos_by_lambda
}

# Choose lambda with a "1se"-style rule: the most regularized lambda whose mean out-of-sample
# objective is within sd(obs) of the best mean.
#
# potential_lambdas:   candidate lambdas (the columns of lambda_oos).
# lambda_oos:          nfolds x n_lambda matrix of out-of-sample objective values.
# min_obs_1se:         minimum number of values used to estimate the spread. When nfolds is smaller,
#                      columns of neighboring lambdas (nearest first, above then below) are pooled
#                      with the best column until at least this many values are collected.
# max_oos_err_allowed: unused. The threshold is always recomputed, and the argument is never
#                      evaluated, so callers may pass anything.
# verbosity:           > 0 prints the pooled values and the threshold.
# Returns the largest lambda within the threshold, or numeric(0) if none qualifies (e.g. sd of a
# single value is NA).
# NOTE(review): the band is one sample standard deviation, not a standard error (sd/sqrt(n));
# confirm this is intended.
lambda_1se_selector <- function(potential_lambdas, lambda_oos, min_obs_1se, max_oos_err_allowed, verbosity) {
  n_lambda = ncol(lambda_oos)
  nfolds = nrow(lambda_oos)
  lambda_oos_means = colMeans(lambda_oos)
  lambda_oos_min_i = which.min(lambda_oos_means)
  obs = lambda_oos[, lambda_oos_min_i]
  if(min_obs_1se>nfolds) {
    # Expand until both sides are exhausted (the bounds checks below skip the missing side). Using
    # seq_len() avoids iterating delta=0, which would re-add the best column when it is at an end.
    max_delta = max(n_lambda-lambda_oos_min_i, lambda_oos_min_i-1)
    for(delta in seq_len(max_delta)) {
      if(lambda_oos_min_i+delta<=n_lambda) {
        obs = c(obs, lambda_oos[, lambda_oos_min_i+delta])
      }
      if(lambda_oos_min_i-delta>=1) {
        obs = c(obs, lambda_oos[, lambda_oos_min_i-delta])
      }
      if(length(obs)>=min_obs_1se) break
    }
    if(verbosity>0) cat(paste("1se pooled obs:", paste(obs, collapse=" "), "\n"))
  }
  # Local threshold; the max_oos_err_allowed argument is deliberately left unevaluated.
  err_threshold = min(lambda_oos_means) + sd(obs)
  if(verbosity>0) cat(paste("max_oos_err_allowed:", paste(err_threshold, collapse=" "), "\n"))
  within_band = which(lambda_oos_means <= err_threshold)
  if(length(within_band)==0) return(numeric(0))
  # Most regularized candidate within the band, independent of the order of potential_lambdas.
  lambda_star = max(potential_lambdas[within_band])
  return(lambda_star)
}

# cv_tr_min_size: We don't want this too large since (with less data) we might otherwise not find the
#                 MSE-min lambda, and if the most detailed partition on the full data is best we might end
#                 up with a lambda that is too large and choose a coarser partition. Could choose 2.
#                 On the other hand, the partitions we generate when this param is too small will be
#                 different and incomparable. Could choose (nfolds-2)/nfolds*min_size.
#                 Default: (nfolds-2)/nfolds*min_size; for nfolds==2 (where that is 0) the average of
#                 2 and min_size/2.
# We used to warn if the best lambda was the smallest, but since there's not much to do about it (we already
# scale cv_tr_min_size), we stopped reporting it.
# Note: We do not want to first fit the full grid and then take potential lambdas as one from each segment that
#       picks another grid. Those lambdas aren't guaranteed to include the true lambda min. We would basically be
#       roughly sampling the true CV lambda function (which is a step function), might miss its minimum,
#       wrongly evaluate the benefit of each subgrid and therefore pick the wrong one.
# ... params sent to cv_pick_lambda_f
# Use cv_obj_fn if you want a different objective function for CV evaluation (rather than the tr,tr training one).
# Returns list(lambda_star, lambda_oos (nfolds x n_lambda), n_cell_table (NULL or an unfilled NA matrix)).
cv_pick_lambda <- function(y, X, d, folds_ret, nfolds, potential_lambdas=NULL, N_est=NA, min_size=5, verbosity=0, lambda_1se=FALSE, 
                           min_obs_1se=5, obj_fn, cv_tr_min_size=NA, est_plan, pr_cl=NULL, cv_bump_samples=0, bump_ratio=1, cv_obj_fn=NULL, ...) {
  #If potential_lambdas is NULL, then only have to iterate through lambda values that change partition_i (for any fold)
  #If is_obj_val_seq is monotonic then this is easy and can do sequentially, but not sure if this is the case
  supplied_lambda = !is.null(potential_lambdas)
  if(verbosity>0) cat("Grid > CV: Started.\n")
  if(is.na(cv_tr_min_size)) {
    # Both branches are observation counts; the nfolds==2 branch must not be scaled by min_size again.
    cv_tr_min_size = as.integer(if(nfolds==2) (2+min_size/2)/2 else (nfolds-2)/nfolds*min_size)
  }
  
  params = c(list(y=y, X_d=X, d=d, folds_ret=folds_ret, nfolds=nfolds, potential_lambdas=potential_lambdas, 
                N_est=N_est, verbosity=verbosity-1,
                obj_fn=obj_fn, cv_tr_min_size=cv_tr_min_size, 
                est_plan=est_plan, cv_bump_samples=cv_bump_samples, bump_ratio=bump_ratio), list(...))
  
  col_rets = my_apply(1:nfolds, cv_pick_lambda_f, verbosity==1 || !is.null(pr_cl), pr_cl, params)
  do_bump = (length(cv_bump_samples)>1 || cv_bump_samples>0)
  
  if(is.null(cv_obj_fn)) cv_obj_fn = obj_fn
  
  n_cell_table = NULL
  if(supplied_lambda) {
    # Folds already evaluated the supplied lambdas.
    n_lambda = length(potential_lambdas)
    lambda_oos = matrix(NA, nrow=nfolds, ncol=n_lambda)
    for(f in seq_len(nfolds)) {
      lambda_oos[f, ] = col_rets[[f]]
    }
  }
  else {
    cvtr_fits = list()
    lambda_ties = list()
    for(f in seq_len(nfolds)) {
      cvtr_fits[[f]] = col_rets[[f]]
      lambda_ties[[f]] = get_all_lambda_ties(cvtr_fits[[f]], do_bump)  #build lambdas. Assuming no slope ties
    }
    
    # Candidates: one lambda above the largest distinct tie, the midpoints between consecutive ties,
    # and half of the smallest tie. Every interval between ties (including the one below the smallest
    # tie) then gets exactly one candidate strictly inside it.
    union_lambda_ties = unique(unlist(lambda_ties))
    if(length(union_lambda_ties)==0) {
      if(verbosity>0) cat("Note: CV folds consistently picked initial model (complexity didn't improve in-sample objective). Defaulting to lambda=0.\n")
      potential_lambdas = c(0)
    }
    else {
      union_lambda_ties = sort(union_lambda_ties, decreasing=TRUE)
      n_ties = length(union_lambda_ties)
      mid_points = union_lambda_ties[-n_ties] + diff(union_lambda_ties)/2
      potential_lambdas = c(union_lambda_ties[1]+1, mid_points, union_lambda_ties[n_ties]/2)
    }
    n_lambda = length(potential_lambdas)
    lambda_oos = matrix(NA, nrow=nfolds, ncol=n_lambda)
    n_cell_table = matrix(NA, nrow=nfolds, ncol=n_lambda) # not filled in; kept for the return shape
    
    for(f in seq_len(nfolds)) {
      list[y_f_tr, y_f_cv, X_f_tr, X_f_cv, d_f_tr, d_f_cv] = split_sample_folds_m(y, X, d, folds_ret, f)
      lambda_oos[f,] = eval_lambdas(cv_obj_fn, est_plan, potential_lambdas, cvtr_fits[[f]], y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est, do_bump)
    }
  }
  lambda_oos_means = colMeans(lambda_oos)
  lambda_oos_min_i = which.min(lambda_oos_means)
  if(lambda_1se) {
    lambda_star = lambda_1se_selector(potential_lambdas, lambda_oos, min_obs_1se, NA, verbosity)
  }
  else {
    lambda_star = potential_lambdas[lambda_oos_min_i]
  }
  if(verbosity>0){ 
    cat("Grid > CV: Finished.")
    cat(paste(" lambda_oos_means=[", paste(lambda_oos_means, collapse=" "), "]."))
    if(length(lambda_star)==0) cat(" Couldn't find any suitable lambdas, returning 1.\n")
    else {
      cat(paste(" potential_lambdas=[", paste(potential_lambdas, collapse=" "), "]."))
      cat(paste(" lambda_star=", lambda_star, ".\n"))
    }
  }
  if(length(lambda_star)==0) {
    lambda_star=1
  }
  
  return(list(lambda_star,lambda_oos, n_cell_table))
}
